package fsm

import (
	"encoding/json"
	"fmt"
	"sync"
)

type (
	// EventType names an event delivered to a state machine instance.
	EventType string

	// StateType names a state of a state machine.
	StateType string

	// ReturnVal is the result code of a Callback; it is looked up in the
	// owning FsmEvent's TransTable to choose the next state.
	ReturnVal int

	// ArgsType carries arbitrary, name-keyed arguments for an event.
	ArgsType map[string]any
)

// EventContext is a single event together with its arguments.
type EventContext struct {
	Event EventType // which event is being delivered
	Args  ArgsType  // handler-defined payload; the engine passes it through untouched
}

type (
	// Callback handles an event while the instance is in state. Its return
	// value rv selects the next state via the owning FsmEvent's TransTable.
	// A non-nil newEvtCtxt is processed next, in that new state; returning a
	// nil newEvtCtxt ends the processing chain.
	Callback func(state StateType, evtCtxt *EventContext) (newEvtCtxt *EventContext, rv ReturnVal)

	// LogFunc is a printf-style logger used to report state transitions.
	LogFunc func(format string, a ...any)
)

// FsmEvent binds an event handler to the transitions it can trigger.
type FsmEvent struct {
	HandleFunc Callback                // invoked for the event; nil makes handling fail with an error
	TransTable map[ReturnVal]StateType // handler return value -> next state; unmapped values are an error
}

// FsmState describes how a single state reacts to events.
type FsmState struct {
	DefaultHandle *FsmEvent               // used for events with no entry in EventTable; may be nil
	EventTable    map[EventType]*FsmEvent // per-event handlers for this state
}

// Fsm is a state machine definition: an ID, an initial state, the table of
// states, and an optional transition logger.
type Fsm struct {
	fsmId      string                  // registry key shared with instances
	initState  StateType               // state new and reset instances start in
	stateTable map[StateType]*FsmState // per-state event handling
	logFunc    LogFunc                 // nil disables transition logging
}

// FsmInst is one running instance of an Fsm. It stores only its FSM's ID and
// resolves the definition through the package-level registry when needed.
type FsmInst struct {
	fsmInstId string    // identifies this instance in errors and logs
	currState StateType // state the instance is in now
	fsmId     string    // ID of the Fsm this instance runs
}

// MarshalJSON implements json.Marshaler for FsmInst. Because all of the
// instance's fields are unexported, they are copied into an exported wire
// struct and emitted as {"fsmInstId": ..., "currState": ..., "fsmId": ...},
// in that key order.
func (inst FsmInst) MarshalJSON() ([]byte, error) {
	payload := struct {
		FsmInstId string    `json:"fsmInstId"`
		CurrState StateType `json:"currState"`
		FsmId     string    `json:"fsmId"`
	}{inst.fsmInstId, inst.currState, inst.fsmId}
	return json.Marshal(payload)
}

// UnmarshalJSON implements json.Unmarshaler for FsmInst and is the inverse
// of MarshalJSON: the keys "fsmInstId", "currState" and "fsmId" populate the
// corresponding unexported fields. Keys absent from data leave the current
// values untouched, and a JSON null is a no-op, following encoding/json's
// convention for Unmarshalers.
func (inst *FsmInst) UnmarshalJSON(data []byte) error {
	// Seed the wire struct with the current values so that keys missing
	// from data do not clobber existing state.
	wire := struct {
		FsmInstId string    `json:"fsmInstId"`
		CurrState StateType `json:"currState"`
		FsmId     string    `json:"fsmId"`
	}{
		FsmInstId: inst.fsmInstId,
		CurrState: inst.currState,
		FsmId:     inst.fsmId,
	}
	// Decode through a single pointer to a struct value. Decoding into a
	// pointer-to-pointer would let a JSON null reset the inner pointer to
	// nil, and the field copies below would then panic.
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}
	inst.fsmInstId = wire.FsmInstId
	inst.currState = wire.CurrState
	inst.fsmId = wire.FsmId
	return nil
}

// fsmId2FsmMap is the package-wide registry of state machines: keys are
// fsmId strings, values are *Fsm. Instances hold only an fsmId and look
// their Fsm up here whenever they need it.
var fsmId2FsmMap sync.Map // fsmId:*Fsm

// NewFsm builds a state machine named fsmId that starts in initState and
// reacts to events according to stateTable. It registers the machine under
// fsmId, replacing any machine previously registered with the same ID, so
// existing instances bound to that ID will use the new definition from then
// on. stateTable is stored as-is, not copied. Transition logging stays off
// until SetLogFunc is called.
func NewFsm(fsmId string, initState StateType, stateTable map[StateType]*FsmState) *Fsm {
	sm := &Fsm{fsmId: fsmId, initState: initState, stateTable: stateTable}
	fsmId2FsmMap.Store(sm.fsmId, sm)
	return sm
}

// NewFsmInst creates an instance of sm called fsmInstId, positioned in sm's
// initial state. sm must be non-nil. The instance keeps only sm's ID and
// resolves the Fsm through the package registry when handling events. If no
// Fsm is registered under that ID at that point, HandleEvent returns an error.
func NewFsmInst(sm *Fsm, fsmInstId string) *FsmInst {
	return &FsmInst{
		fsmId:     sm.fsmId,
		fsmInstId: fsmInstId,
		currState: sm.initState,
	}
}

// SetLogFunc installs logFunc as sm's printf-style logger. It is invoked once
// for every successful transition of any instance of sm; passing nil turns
// logging off. NOTE: the assignment is not synchronized, so call this before
// events are dispatched concurrently.
func (sm *Fsm) SetLogFunc(logFunc LogFunc) {
	sm.logFunc = logFunc
}

// HandleEvent feeds evtCtxt to the instance and keeps going. Each handler
// may return a follow-up event context, which is dispatched against the
// state the instance has just moved to.
//
// Processing ends with a nil error once a handler returns a nil context. It
// stops at the first failing step and returns that step's error; transitions
// made by earlier steps are kept. The chain has no step limit, so handlers
// must eventually return a nil context.
func (fsmInst *FsmInst) HandleEvent(evtCtxt *EventContext) error {
	current := evtCtxt
	for {
		next, err := fsmInst.handleEventImpl(current)
		switch {
		case err != nil:
			return err
		case next == nil:
			return nil
		}
		current = next
	}
}

// handleEventImpl performs a single dispatch step for fsmInst:
//  1. resolves the instance's Fsm from the package registry by fsmId;
//  2. picks the handler for evtCtxt.Event in the current state, falling back
//     to the state's DefaultHandle when the event has no (or a nil) entry;
//  3. runs the handler, then moves the instance to the state that the
//     handler's TransTable maps its return value to, logging the transition
//     when a LogFunc is set.
//
// It returns the follow-up event context produced by the handler (nil when
// there is nothing more to process). On error the instance's state is left
// unchanged. If the error is an unmapped return value, the handler has
// already run and its side effects are not undone.
func (fsmInst *FsmInst) handleEventImpl(evtCtxt *EventContext) (*EventContext, error) {
	val, ok := fsmId2FsmMap.Load(fsmInst.fsmId)
	if !ok {
		return nil, fmt.Errorf("FSM[%s] INST[%s]: Can not find FSM", fsmInst.fsmId, fsmInst.fsmInstId)
	}
	sm := val.(*Fsm)

	currState := fsmInst.currState

	if sm.stateTable == nil {
		return nil, fmt.Errorf("FSM[%s] INST[%s]: state table is nil", sm.fsmId, fsmInst.fsmInstId)
	}

	smState, ok := sm.stateTable[currState]
	if !ok {
		return nil, fmt.Errorf("FSM[%s] INST[%s]: Can not find state[%s]", sm.fsmId, fsmInst.fsmInstId, currState)
	}
	if smState == nil {
		return nil, fmt.Errorf("FSM[%s] INST[%s]: state is nil", sm.fsmId, fsmInst.fsmInstId)
	}

	// Report a nil context (e.g. a nil initial argument to HandleEvent) as an
	// error instead of panicking on the dereference below.
	if evtCtxt == nil {
		return nil, fmt.Errorf("FSM[%s] INST[%s] STATE[%s]: event context is nil", sm.fsmId, fsmInst.fsmInstId, currState)
	}

	// A missing entry and an explicit nil entry both fall back to the default
	// handler, so the error below only fires when no handler exists at all.
	smEvent := smState.EventTable[evtCtxt.Event]
	if smEvent == nil {
		smEvent = smState.DefaultHandle
	}
	if smEvent == nil {
		return nil, fmt.Errorf("FSM[%s] INST[%s] STATE[%s]: event[%s] handle and default handle is nil", sm.fsmId, fsmInst.fsmInstId, currState, evtCtxt.Event)
	}
	if smEvent.HandleFunc == nil {
		return nil, fmt.Errorf("FSM[%s] INST[%s] STATE[%s]: handle func is nil", sm.fsmId, fsmInst.fsmInstId, currState)
	}

	// Capture the event before the callback runs: the handler receives
	// evtCtxt by pointer and may modify it, but the log must name the event
	// that was actually dispatched.
	oldEvent := evtCtxt.Event
	newEvtCtxt, rv := smEvent.HandleFunc(currState, evtCtxt)

	newState, ok := smEvent.TransTable[rv]
	if !ok {
		return nil, fmt.Errorf("FSM[%s] INST[%s] STATE[%s]: Can not find new state by callback return value %d",
			sm.fsmId, fsmInst.fsmInstId, currState, rv)
	}
	fsmInst.currState = newState

	if sm.logFunc != nil {
		sm.logFunc("FSM[%s] INST[%s]: %s[%s] --> %s\n", sm.fsmId, fsmInst.fsmInstId, currState, oldEvent, newState)
	}

	return newEvtCtxt, nil
}

// CurrState reports the state the instance is currently in.
func (fsmInst *FsmInst) CurrState() StateType { return fsmInst.currState }

// ResetState moves the instance back to the initial state of the Fsm
// registered under its fsmId. If no Fsm is registered under that ID, the
// call silently leaves the instance unchanged.
func (fsmInst *FsmInst) ResetState() {
	if registered, found := fsmId2FsmMap.Load(fsmInst.fsmId); found {
		fsmInst.currState = registered.(*Fsm).initState
	}
}
